{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright (C) 2016 - 2019 Pinard Liu(liujianping-ok@163.com)\n",
    "\n",
    "https://www.cnblogs.com/pinard\n",
    "\n",
    "Permission given to modify the code as long as you keep this declaration at the top\n",
    "\n",
    "用hmmlearn学习隐马尔科夫模型HMM https://www.cnblogs.com/pinard/p/7001397.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('The ball picked:', 'red, white, red')\n",
      "('The hidden box', 'box3, box3, box3')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Python27\\lib\\site-packages\\IPython\\kernel\\__main__.py:31: VisibleDeprecationWarning: converting an array with ndim > 0 to an index will result in an error in the future\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from hmmlearn import hmm\n",
    "\n",
    "# Hidden states: the box each ball is drawn from.\n",
    "states = [\"box 1\", \"box 2\", \"box3\"]\n",
    "n_states = len(states)\n",
    "\n",
    "# Observable symbols: ball colours, encoded by their index in this list.\n",
    "observations = [\"red\", \"white\"]\n",
    "n_observations = len(observations)\n",
    "\n",
    "# One starting probability per hidden state.\n",
    "start_probability = np.array([0.2, 0.4, 0.4])\n",
    "\n",
    "# 3 x 3 state-to-state transition matrix; every row sums to 1.\n",
    "transition_probability = np.array([\n",
    "  [0.5, 0.2, 0.3],\n",
    "  [0.3, 0.5, 0.2],\n",
    "  [0.2, 0.3, 0.5]\n",
    "])\n",
    "\n",
    "# 3 x 2 matrix: one row per state, one column per observable symbol.\n",
    "emission_probability = np.array([\n",
    "  [0.5, 0.5],\n",
    "  [0.4, 0.6],\n",
    "  [0.7, 0.3]\n",
    "])\n",
    "\n",
    "# The parameters are assigned directly instead of being learned with fit().\n",
    "# NOTE(review): recent hmmlearn releases appear to provide the\n",
    "# one-symbol-per-step model as hmm.CategoricalHMM -- confirm which class\n",
    "# matches the installed version.\n",
    "model = hmm.MultinomialHMM(n_components=n_states)\n",
    "model.startprob_ = start_probability\n",
    "model.transmat_ = transition_probability\n",
    "model.emissionprob_ = emission_probability\n",
    "\n",
    "# Observed sequence red, white, red as a column vector of shape (3, 1).\n",
    "seen = np.array([[0, 1, 0]]).T\n",
    "logprob, box = model.decode(seen, algorithm=\"viterbi\")\n",
    "\n",
    "# Flatten before indexing: each row of `seen` is a length-1 array, and a\n",
    "# Python list cannot be indexed with a non-scalar array.\n",
    "print(\"The ball picked:\", \", \".join(observations[i] for i in seen.flatten()))\n",
    "print(\"The hidden box\", \", \".join(states[i] for i in box))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('The ball picked:', 'red, white, red')\n",
      "('The hidden box', 'box3, box3, box3')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Python27\\lib\\site-packages\\IPython\\kernel\\__main__.py:2: VisibleDeprecationWarning: converting an array with ndim > 0 to an index will result in an error in the future\n",
      "  from IPython.kernel.zmq import kernelapp as app\n"
     ]
    }
   ],
   "source": [
    "# predict() returns only the hidden-state sequence for `seen`.\n",
    "box2 = model.predict(seen)\n",
    "\n",
    "# Flatten before indexing: each row of `seen` is a length-1 array, and a\n",
    "# Python list cannot be indexed with a non-scalar array.\n",
    "print(\"The ball picked:\", \", \".join(observations[i] for i in seen.flatten()))\n",
    "print(\"The hidden box\", \", \".join(states[i] for i in box2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-2.03854530992\n"
     ]
    }
   ],
   "source": [
    "# Score the observed sequence `seen` under the hand-specified model.\n",
    "# print() is called as a function so the cell runs on the Python 3 kernel.\n",
    "print(model.score(seen))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[  1.31697183e-12   3.07279825e-15   1.00000000e+00]\n",
      "[[  1.48161608e-01   1.50331031e-01   7.01507361e-01]\n",
      " [  2.19355550e-01   2.31744543e-01   5.48899907e-01]\n",
      " [  4.55650613e-01   5.44338342e-01   1.10450680e-05]]\n",
      "[[  2.19032398e-01   7.80967602e-01]\n",
      " [  1.15445252e-01   8.84554748e-01]\n",
      " [  9.99192283e-01   8.07717131e-04]]\n",
      "-6.51866416797\n",
      "[  9.99999997e-01   3.19111963e-09   2.84005362e-23]\n",
      "[[  8.42932318e-09   4.12422449e-02   9.58757747e-01]\n",
      " [  1.35449734e-01   5.71715372e-01   2.92834894e-01]\n",
      " [  5.77139011e-01   9.57292050e-02   3.27131784e-01]]\n",
      "[[  9.99987996e-01   1.20035621e-05]\n",
      " [  5.22916524e-01   4.77083476e-01]\n",
      " [  1.51107630e-01   8.48892370e-01]]\n",
      "-6.61149770824\n",
      "[  1.00000000e+00   1.30722027e-10   3.08257576e-12]\n",
      "[[  1.11200351e-05   4.22408201e-01   5.77580679e-01]\n",
      " [  6.44472673e-01   1.69915233e-01   1.85612094e-01]\n",
      " [  5.92638365e-01   1.91056314e-01   2.16305321e-01]]\n",
      "[[  9.99515188e-01   4.84812380e-04]\n",
      " [  1.87229564e-01   8.12770436e-01]\n",
      " [  1.49192055e-01   8.50807945e-01]]\n",
      "-6.54281680533\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from hmmlearn import hmm\n",
    "\n",
    "states = [\"box 1\", \"box 2\", \"box3\"]\n",
    "n_states = len(states)\n",
    "\n",
    "observations = [\"red\", \"white\"]\n",
    "n_observations = len(observations)\n",
    "\n",
    "# This time the parameters are estimated from data rather than set by hand.\n",
    "# NOTE(review): recent hmmlearn releases appear to provide the\n",
    "# one-symbol-per-step model as hmm.CategoricalHMM -- confirm which class\n",
    "# matches the installed version.\n",
    "model2 = hmm.MultinomialHMM(n_components=n_states, n_iter=20, tol=0.01)\n",
    "X2 = np.array([[0,1,0,1],[0,0,0,1],[1,0,1,1]])\n",
    "\n",
    "# Fit the same model three times, printing the learned parameters and the\n",
    "# score after each run. No random seed is fixed, so the runs can end at\n",
    "# different parameter estimates; compare the printed scores between runs.\n",
    "for attempt in range(3):\n",
    "    model2.fit(X2)\n",
    "    print(model2.startprob_)\n",
    "    print(model2.transmat_)\n",
    "    print(model2.emissionprob_)\n",
    "    print(model2.score(X2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Four-state Gaussian HMM over 2-D observations. All parameters are\n",
    "# assigned by hand below instead of being estimated with fit().\n",
    "model3 = hmm.GaussianHMM(covariance_type=\"full\", n_components=4)\n",
    "\n",
    "# Initial state distribution; the fourth state has zero starting probability.\n",
    "startprob = np.array([0.6, 0.3, 0.1, 0.0])\n",
    "\n",
    "# State transition matrix. The zeros at [0, 2] and [2, 0] mean the first\n",
    "# and third states never move directly to one another.\n",
    "transmat = np.array([\n",
    "    [0.7, 0.2, 0.0, 0.1],\n",
    "    [0.3, 0.5, 0.2, 0.0],\n",
    "    [0.0, 0.3, 0.5, 0.2],\n",
    "    [0.2, 0.0, 0.2, 0.6],\n",
    "])\n",
    "\n",
    "# One 2-D mean vector per state.\n",
    "means = np.array([\n",
    "    [0.0, 0.0],\n",
    "    [0.0, 11.0],\n",
    "    [9.0, 10.0],\n",
    "    [11.0, -1.0],\n",
    "])\n",
    "\n",
    "# Every state gets the same covariance, 0.5 * I, stacked to shape (4, 2, 2).\n",
    "covars = np.tile(0.5 * np.identity(2), (4, 1, 1))\n",
    "\n",
    "model3.startprob_ = startprob\n",
    "model3.transmat_ = transmat\n",
    "model3.means_ = means\n",
    "model3.covars_ = covars"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Python27\\lib\\site-packages\\sklearn\\utils\\deprecation.py:70: DeprecationWarning: Function log_multivariate_normal_density is deprecated; The function log_multivariate_normal_density is deprecated in 0.18 and will be removed in 0.20.\n",
      "  warnings.warn(msg, category=DeprecationWarning)\n"
     ]
    }
   ],
   "source": [
    "# Three 2-D real-valued observations (this rebinds `seen` from the earlier cells).\n",
    "seen = np.array([[1.1, 2.0], [-1, 2.0], [3, 7]])\n",
    "\n",
    "# Decode with model3, the GaussianHMM configured for 2-D observations.\n",
    "# `model` is the discrete-symbol HMM from the first cell and does not fit this data.\n",
    "logprob, state = model3.decode(seen, algorithm=\"viterbi\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0 0 1]\n"
     ]
    }
   ],
   "source": [
    "# Show the decoded hidden-state index for each observation.\n",
    "# print() is called as a function so the cell runs on the Python 3 kernel.\n",
    "print(state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-41.1211281377\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Python27\\lib\\site-packages\\sklearn\\utils\\deprecation.py:70: DeprecationWarning: Function log_multivariate_normal_density is deprecated; The function log_multivariate_normal_density is deprecated in 0.18 and will be removed in 0.20.\n",
      "  warnings.warn(msg, category=DeprecationWarning)\n",
      "C:\\Python27\\lib\\site-packages\\hmmlearn\\base.py:459: RuntimeWarning: divide by zero encountered in log\n",
      "  np.log(self.startprob_),\n",
      "C:\\Python27\\lib\\site-packages\\hmmlearn\\base.py:460: RuntimeWarning: divide by zero encountered in log\n",
      "  np.log(self.transmat_),\n"
     ]
    }
   ],
   "source": [
    "# Score the 2-D observation sequence `seen` under the Gaussian model.\n",
    "# print() is called as a function so the cell runs on the Python 3 kernel.\n",
    "print(model3.score(seen))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
